library(here)
library(tidyverse)
library(scales)
library(redist)
library(sf)
library(patchwork)

# Party colours as hex strings (not referenced again in this file; presumably
# read by the sourced helpers -- TODO confirm)
DEM <- "#0064B0"
GOP <- "#A0442C"

# Project helper functions (this is where plot_boxes() comes from)
source(here("R/00_custom_functions.R"))

# Location of the cached summary table of simulated plans
path <- here("data/sims_partisan_sum.rds")

# Flag draw labels that contain "cd", "<init" or "sld_up"; the rest of the
# script treats those draws as reference plans rather than simulated ones.
# Returns a logical vector the same length as `x`.
is_ref <- function(x) {
    str_detect(x, "cd|<init|sld_up")
}

# Load plans and simulations -----------------------------------------------------
# Load the cached summary table if it exists; otherwise build it from the raw
# simulation outputs and write the cache. A freshly built table is used
# directly from memory rather than being re-read (and xz-decompressed) from
# the file that was just written.
if (file.exists(path)) {
    pcts <- read_rds(path)
} else {
    # Precinct-level maps. "redist_map" is prepended to each class vector,
    # presumably so the redist helpers below accept the objects -- TODO confirm.
    pa_shp <- read_rds(here("data/PA/pa_shp.rds"))
    class(pa_shp) <- c("redist_map", class(pa_shp))

    sc_shp <- read_rds(here("data/SC/sc_shp.rds"))
    class(sc_shp) <- c("redist_map", class(sc_shp))

    nc_shp <- read_rds(here("data/NC/nc_shp.rds"))
    class(nc_shp) <- c("redist_map", class(nc_shp))

    de_shp <- read_rds(here("data/DE/de.Rds"))
    class(de_shp) <- c("redist_map", class(de_shp))

    # Convert a plans object to a plain tibble and strip the "plans",
    # "prec_pop" and "wgt" attributes (presumably to keep the cached object
    # small -- TODO confirm).
    clean_plans <- function(pl) {
        as_tibble(pl) %>%
            `attr<-`("plans", NULL) %>%
            `attr<-`("prec_pop", NULL) %>%
            `attr<-`("wgt", NULL)
    }

    # One tally helper per state. Each adds `grp`, the first vote column as a
    # fraction of the sum of the two vote columns within each district (via
    # redist's group_frac()), renumbers districts by `grp` with number_by(),
    # and then strips the heavy attributes. The helpers differ only in the
    # map and the vote columns they use.
    tally_pa <- function(pl) {
        mutate(pl, grp = group_frac(pa_shp, ndv, ndv + nrv)) %>%
            number_by(grp) %>%
            clean_plans()
    }
    tally_sc <- function(pl) {
        mutate(pl, grp = group_frac(sc_shp, gov_vest_DEM, gov_vest_DEM + gov_vest_REP)) %>%
            number_by(grp) %>%
            clean_plans()
    }
    tally_nc <- function(pl) {
        mutate(pl, grp = group_frac(nc_shp, EL12G_GV_D, EL12G_GV_D + EL12G_GV_R)) %>%
            number_by(grp) %>%
            clean_plans()
    }
    tally_de <- function(pl) {
        mutate(pl, grp = group_frac(de_shp, G20PREDBID, G20PREDBID + G20PRERTRU)) %>%
            number_by(grp) %>%
            clean_plans()
    }

    # Simulated plans per map set; the list names become the `source` column
    # when the lists are stacked below.
    pa_plans <- list(
        das04 = read_rds(here("data/PA/sim_das04_cty_001_10k.rds")) %>% tally_pa(),
        das12 = read_rds(here("data/PA/sim_das12_cty_001_10k.rds")) %>% tally_pa(),
        das19 = read_rds(here("data/PA/sim_das19_cty_001_10k.rds")) %>% tally_pa(),
        orig = read_rds(here("data/PA/sim_orig_cty_001_10k.rds")) %>% tally_pa()
    )

    sc_cd_plans <- list(
        das04 = read_rds(here("data/SC/sim/CD/plans_da04.rds")) %>% tally_sc(),
        das12 = read_rds(here("data/SC/sim/CD/plans_da12.rds")) %>% tally_sc(),
        das19 = read_rds(here("data/SC/sim/CD/plans_da19.rds")) %>% tally_sc(),
        orig = read_rds(here("data/SC/sim/CD/plans_orig.rds")) %>% tally_sc()
    )

    sc_hd_plans <- list(
        das04 = read_rds(here("data/SC/sim/HD_ms-parallel/plans_da04.rds")) %>% tally_sc(),
        das12 = read_rds(here("data/SC/sim/HD_ms-parallel/plans_da12.rds")) %>% tally_sc(),
        das19 = read_rds(here("data/SC/sim/HD_ms-parallel/plans_da19.rds")) %>% tally_sc(),
        orig = read_rds(here("data/SC/sim/HD_ms-parallel/plans_orig.rds")) %>% tally_sc()
    )

    nc_cd_plans <- list(
        das04 = read_rds(here("data/NC/sim/CD/plans_da04.rds")) %>% tally_nc(),
        das12 = read_rds(here("data/NC/sim/CD/plans_da12.rds")) %>% tally_nc(),
        das19 = read_rds(here("data/NC/sim/CD/plans_da19.rds")) %>% tally_nc(),
        orig = read_rds(here("data/NC/sim/CD/plans_orig.rds")) %>% tally_nc()
    )

    de_plans <- list(
        das04 = read_rds(here("data/DE/sim/plans_v4.rds")) %>% tally_de(),
        das12 = read_rds(here("data/DE/sim/plans_v12.rds")) %>% tally_de(),
        das19 = read_rds(here("data/DE/sim/plans_v19.rds")) %>% tally_de(),
        orig = read_rds(here("data/DE/sim/plans_cen.rds")) %>% tally_de()
    )

    # The maps are only needed by the tally helpers above
    rm(pa_shp, nc_shp, sc_shp, de_shp)


    # Combine and cache -------------------------------------------------------

    # Stack everything into one table keyed by `state` (map set) and `source`
    pcts <- bind_rows(
        pa_cd = bind_rows(pa_plans, .id = "source"),
        sc_cd = bind_rows(sc_cd_plans, .id = "source"),
        nc_cd = bind_rows(nc_cd_plans, .id = "source"),
        sc_hd = bind_rows(sc_hd_plans, .id = "source"),
        de_sd = bind_rows(de_plans, .id = "source"),
        .id = "state"
    )
    rm(pa_plans, sc_hd_plans, sc_cd_plans, nc_cd_plans, de_plans)

    write_rds(pcts, path, compress = "xz")
}

# Warning: Some districts may have been dropped. This will prevent summary statistics from working correctly.

# Record each (state, source, district) group's reference-plan value of `grp`
# as `ref_grp`, then drop the reference rows so only the remaining draws are
# left. Rows whose district number is not a whole number are discarded first.
pcts_grp0 <- pcts %>%
    filter(district == as.integer(district)) %>%
    group_by(state, source, district) %>%
    mutate(ref_grp = grp[is_ref(draw)][1]) %>%
    ungroup() %>%
    filter(!is_ref(draw))

# Bin the South Carolina House districts into runs of five consecutive
# districts, labelled by the first district of each run (1-5 -> 1, 6-10 -> 6,
# ...), then restack them underneath the other map sets.
# NOTE(review): bins are 5 wide here, while patch_boxes_party() passes
# grouped = c(by = 4, max = 124) to plot_boxes() -- confirm the two agree.
pcts_sc_hd <- pcts_grp0 %>%
    filter(state == "sc_hd") %>%
    mutate(district_new = 1 + 5 * as.integer((district - 1) / 5))

# `district_new` is NA for every non-SC-House row, so coalesce() keeps the
# original district number there.
pcts_grp <- pcts_grp0 %>%
    filter(state != "sc_hd") %>%
    bind_rows(pcts_sc_hd) %>%
    mutate(district = coalesce(district_new, district))

# Summary statistics ---
# Five-number summary of the gap between each draw's `grp` and its reference
# value, per (state, source, district).
pcts_box <- pcts_grp %>%
    mutate(gap = grp - ref_grp) %>%
    group_by(state, source, district) %>%
    summarize(
        med = median(gap),
        q1 = quantile(gap, 0.25),
        q3 = quantile(gap, 0.75),
        low = min(gap),
        high = max(gap),
        .groups = "drop"
    )

# The full table is no longer needed once the summaries exist
rm(pcts)


#' Assemble the five partisan box-plot panels into one patchwork figure
#'
#' Builds one `plot_boxes()` panel per map set and combines them with
#' patchwork: the North Carolina and South Carolina congressional panels share
#' the first row, and the South Carolina House, Delaware Senate and
#' Pennsylvania panels each take a full-width row below. Legends are collected
#' and placed at the bottom.
#'
#' @param data table with all summary stats (defaults to the global `pcts_box`)
#' @param sources which sources to show; passed to `plot_boxes()` as `show`
#' @return the combined patchwork object
patch_boxes_party <- function(data = pcts_box, sources = c("Census 2010", "DAS-12.2", "DAS-4.5")) {
    # Letters are matched to panels in the order the panels are added below
    design <-
        "
         AAABB
         CCCCC
         DDDDD
         EEEEE
        "

    # A: top left
    nc_panel <- plot_boxes(data, geo = "nc_cd", title = "North Carolina Congressional", show = sources)
    # B: top right, so the y-axis label is dropped
    sc_panel <- plot_boxes(data, geo = "sc_cd", title = "South Carolina Congressional", show = sources) +
        labs(y = NULL)
    # C: many districts, so the x-axis labels are tilted
    hd_panel <- plot_boxes(data, geo = "sc_hd", grouped = c(by = 4, max = 124), title = "South Carolina State House", show = sources) +
        theme(axis.text.x = element_text(angle = 30, hjust = 1))
    # D
    de_panel <- plot_boxes(data, geo = "de_sd", title = "Delaware State Senate", show = sources)
    # E: bottom row, carries the x-axis title
    pa_panel <- plot_boxes(data, geo = "pa_cd", title = "Pennsylvania Congressional", show = sources) +
        labs(x = "Districts, ordered by Democratic share")

    combined <- nc_panel + sc_panel + hd_panel + de_panel + pa_panel +
        plot_layout(design = design, guides = "collect")

    # `&` applies the theme to every panel in the patchwork
    combined & theme(legend.position = "bottom")
}

# Generate plots (patch_boxes_party() is defined above; plot_boxes() comes from 00_)
gg_12_party <- patch_boxes_party(pcts_box, sources = c("Census 2010", "DAS-12.2", "DAS-4.5"))
gg_19_party <- patch_boxes_party(pcts_box, sources = c("Census 2010", "DAS-19.61"))


# Save plots
# Paths are anchored with here(), like every other path in this script, so the
# figures land in the project's figs/ directory whatever the working directory
# is. The directory is created if it does not exist yet.
dir.create(here("figs"), showWarnings = FALSE, recursive = TRUE)
ggsave(here("figs/partisan_boxplots.pdf"), gg_12_party, width = 6.5, height = 8.25)
ggsave(here("figs/partisan_boxplots_DAS19.pdf"), gg_19_party, width = 6.5, height = 8.25)
